import torch
import torch.nn as nn
from importlib import import_module
from typing import Optional
from .utils import QuantMethod, QuantDtype


class IPEXWeightOnlyQuantizedLinear(nn.Module):
    r"""
    A weight-only quantized (WOQ) linear module with floating point tensor as inputs and outputs.
    Weight is dequantized at runtime for computation.

    This class is a thin, device-dispatching wrapper: the actual computation is
    delegated to a backend implementation class that is looked up by device type
    (see ``module_mapping``) and imported lazily in :meth:`from_weight`.
    """

    # Maps a device type (``tensor.device.type``) to the dotted path of the module
    # that provides the backend implementation for that device. This mapping is
    # the single source of truth for which device types are supported.
    module_mapping = {
        "cpu": "intel_extension_for_pytorch.nn.modules.weight_only_quantization",
    }

    # Name of the implementation class looked up inside the mapped module.
    impl_name = "WeightOnlyQuantizedLinear"

    def __init__(self, woq_linear_impl):
        r"""Wrap an already constructed backend WOQ linear module.

        Args:
            woq_linear_impl: the backend module that ``forward`` delegates to.
                Prefer :meth:`from_weight` to build it from quantized weights.
        """
        super().__init__()
        self.woq_linear_impl = woq_linear_impl

    @classmethod
    def from_weight(
        cls,
        qweight: torch.Tensor,
        scales: torch.Tensor,
        zero_points: torch.Tensor,
        in_features: int,
        out_features: int,
        qconfig=None,
        bias: Optional[torch.Tensor] = None,
        group_size: int = -1,
        g_idx: Optional[torch.Tensor] = None,
        quant_method: QuantMethod = QuantMethod.GPTQ_GEMM,
        dtype: QuantDtype = QuantDtype.INT4,
        **kwargs
    ):
        r"""Create a weight-only quantized module from weight

        The backend implementation is selected from the device of ``qweight`` and
        all arguments are forwarded, unchanged and in the same order, to that
        implementation's own ``from_weight``.

        Args:
            qweight (Tensor): tensor in int32 dtype and contains actually int4 data
            scales (Tensor): scales for qweight
            zero_points (Tensor): zero points for qweight
            in_features (int): size of each input sample
            out_features (int): size of each output sample
            qconfig (object): Defining the IPEX quantization recipe for Weight only quantization.
                Default value is ``None``.
            bias (Tensor or None): bias for linear
            group_size: Group size for weight quantization
            g_idx: Indices of groups for each input channel of weight. Generated by
                GPTQ with act-order.
            quant_method: Quantization method, such as GPTQ, AWQ, ...
            dtype (QuantDtype): quantization data type
            **kwargs: extra keyword arguments passed through to the backend
                implementation's ``from_weight``.

        Returns:
            An instance of ``cls`` wrapping the backend module that was created.

        Raises:
            AssertionError: if the device type of ``qweight`` has no entry in
                ``module_mapping``.

        """
        device_type = qweight.device.type
        # Validate against the mapping itself so the supported set cannot drift
        # from it. Raised explicitly rather than via ``assert`` so the check is
        # still performed under ``python -O``; the exception type and message
        # are kept for backward compatibility with existing callers.
        if device_type not in cls.module_mapping:
            raise AssertionError("Device type not supported.")

        # Import lazily: the backend package is only needed once a module is
        # actually being built for that device.
        backend_module = import_module(cls.module_mapping[device_type])
        woq_linear_impl_cls = getattr(backend_module, cls.impl_name)

        woq_linear_impl = woq_linear_impl_cls.from_weight(
            qweight,
            scales,
            zero_points,
            in_features,
            out_features,
            qconfig,
            bias,
            group_size,
            g_idx,
            quant_method,
            dtype,
            **kwargs
        )
        return cls(woq_linear_impl)

    def forward(self, x, **kwargs):
        r"""Run the wrapped backend module on ``x``.

        Args:
            x: input passed as-is to the backend module.
            **kwargs: extra keyword arguments passed as-is to the backend module.

        Returns:
            Whatever the backend module returns for the given input.
        """
        return self.woq_linear_impl(x, **kwargs)
